import jax
import jax.numpy as jnp
import numpy as np
import pickle
import os
from typing import Dict, Any, Optional, Tuple
from torch import nn
import torch
import torch.nn.functional as F
import gc, jax

def save_model(model, filepath='sae_model.pkl'):
    """
    Persist a SparseAutoencoder to disk in two formats.

    1. A pickle at ``filepath`` containing the parameters (as NumPy arrays)
       together with the ``k``, ``embed_dim`` and ``hidden_dim`` attributes.
    2. A compressed ``<filepath without extension>_params.npz`` archive holding
       only the five parameter arrays, for loading outside this module.

    Args:
        model: Object exposing ``params`` (dict with 'encoder' and 'decoder'
            sub-dicts of 'weights'/'bias', plus 'tied_bias'), ``k``,
            ``embed_dim`` and ``hidden_dim``.
        filepath: Destination of the pickle file.
    """
    src = model.params

    # Host-side NumPy copies of every parameter (JAX arrays are converted too).
    enc_w = np.array(src['encoder']['weights'])
    enc_b = np.array(src['encoder']['bias'])
    dec_w = np.array(src['decoder']['weights'])
    dec_b = np.array(src['decoder']['bias'])
    tied = np.array(src['tied_bias'])

    payload = {
        'params': {
            'encoder': {'weights': enc_w, 'bias': enc_b},
            'decoder': {'weights': dec_w, 'bias': dec_b},
            'tied_bias': tied,
        },
        'k': model.k,
        'embed_dim': model.embed_dim,
        'hidden_dim': model.hidden_dim,
    }

    with open(filepath, 'wb') as fh:
        pickle.dump(payload, fh)
    print(f"Model saved to {filepath}")

    # Companion params-only archive next to the pickle.
    stem, _ext = os.path.splitext(filepath)
    params_only_path = f"{stem}_params.npz"
    np.savez_compressed(
        params_only_path,
        encoder_weights=enc_w,
        encoder_bias=enc_b,
        decoder_weights=dec_w,
        decoder_bias=dec_b,
        tied_bias=tied,
    )
    print(f"Model parameters also saved to {params_only_path}")

def load_model(filepath='sae_model.pkl', model_class=None):
    """
    Rebuild a SparseAutoencoder from a pickle written by ``save_model``.

    The stored NumPy parameters are converted to JAX arrays and passed, with
    the stored ``k``/``embed_dim``/``hidden_dim``, as keyword arguments to
    ``model_class``.

    Args:
        filepath: Path to the pickle file.
        model_class: Class used to reconstruct the model. Defaults to
            ``sae_jax.SparseAutoencoder`` when None.

    Returns:
        The instance returned by ``model_class(params=..., k=..., embed_dim=..., hidden_dim=...)``.
    """
    if model_class is None:
        # Imported lazily to avoid a circular import with sae_jax.
        from sae_jax import SparseAutoencoder as model_class

    # SECURITY: pickle.load can execute arbitrary code -- only load trusted files.
    with open(filepath, 'rb') as fh:
        stored = pickle.load(fh)

    raw = stored['params']
    params = {
        section: {name: jnp.array(raw[section][name]) for name in ('weights', 'bias')}
        for section in ('encoder', 'decoder')
    }
    params['tied_bias'] = jnp.array(raw['tied_bias'])

    model = model_class(
        params=params,
        k=stored['k'],
        embed_dim=stored['embed_dim'],
        hidden_dim=stored['hidden_dim'],
    )

    print(f"Model loaded from {filepath}")
    return model

def load_model_params_only(filepath='sae_model_params.npz'):
    """
    Load only the parameters of a SparseAutoencoder from a ``.npz`` archive.

    The archive is expected to contain the keys written by ``save_model``:
    ``encoder_weights``, ``encoder_bias``, ``decoder_weights``,
    ``decoder_bias`` and ``tied_bias``.

    Args:
        filepath: Path to the saved parameters (.npz file).

    Returns:
        Nested dict of JAX arrays:
        ``{'encoder': {'weights', 'bias'}, 'decoder': {'weights', 'bias'}, 'tied_bias'}``.

    Raises:
        KeyError: if one of the expected arrays is missing from the archive.
    """
    # np.load on an .npz returns an NpzFile that keeps the underlying file
    # open; the context manager closes it once the arrays have been copied
    # into JAX (jnp.array copies, so the results stay valid after closing).
    with np.load(filepath) as data:
        params = {
            'encoder': {
                'weights': jnp.array(data['encoder_weights']),
                'bias': jnp.array(data['encoder_bias']),
            },
            'decoder': {
                'weights': jnp.array(data['decoder_weights']),
                'bias': jnp.array(data['decoder_bias']),
            },
            'tied_bias': jnp.array(data['tied_bias']),
        }

    print(f"Model parameters loaded from {filepath}")
    return params

def save_metadata(metadata, filepath='sae_metadata.pkl'):
    """
    Pickle a flat dict of training metadata / hyperparameters.

    Top-level values that look like arrays (they have both ``dtype`` and
    ``shape`` attributes, e.g. JAX or NumPy arrays) are stored as NumPy
    arrays; all other values are pickled unchanged. Nested containers are not
    inspected.

    Args:
        metadata: Dictionary with metadata to save.
        filepath: Path where the metadata will be saved.
    """
    def _to_host(value):
        looks_like_array = hasattr(value, 'dtype') and hasattr(value, 'shape')
        return np.array(value) if looks_like_array else value

    processed_metadata = {key: _to_host(value) for key, value in metadata.items()}

    with open(filepath, 'wb') as fh:
        pickle.dump(processed_metadata, fh)

    print(f"Metadata saved to {filepath}")

# Checkpoint saving during training
def save_checkpoint(model, optimizer_state, step, filepath_prefix='sae_checkpoint'):
    """
    Pickle a training checkpoint to ``f"{filepath_prefix}_step{step}.pkl"``.

    Model parameters are converted to NumPy arrays. Dict-valued entries of
    ``model.params`` are converted leaf by leaf (one level deep); other entries
    are converted directly. ``optimizer_state`` is pickled as given.

    Args:
        model: SparseAutoencoder-like object with ``params``, ``k``,
            ``embed_dim`` and ``hidden_dim``.
        optimizer_state: State of the optimizer (stored as-is).
        step: Current training step (also embedded in the file name).
        filepath_prefix: Prefix for the checkpoint files.
    """
    filepath = f"{filepath_prefix}_step{step}.pkl"

    def _as_numpy(entry):
        if isinstance(entry, dict):
            return {leaf_name: np.array(leaf) for leaf_name, leaf in entry.items()}
        return np.array(entry)

    host_params = {name: _as_numpy(entry) for name, entry in model.params.items()}

    checkpoint = {
        'model': {
            'params': host_params,
            'k': model.k,
            'embed_dim': model.embed_dim,
            'hidden_dim': model.hidden_dim,
        },
        'optimizer_state': optimizer_state,
        'step': step,
    }

    with open(filepath, 'wb') as fh:
        pickle.dump(checkpoint, fh)

    print(f"Checkpoint saved at step {step} to {filepath}")

def load_checkpoint(filepath, optimizer=None, model_class=None):
    """
    Load a checkpoint written by ``save_checkpoint`` and rebuild the model.

    Stored NumPy parameters are converted back to JAX arrays (dict-valued
    entries leaf by leaf, one level deep) and passed with the stored
    ``k``/``embed_dim``/``hidden_dim`` to ``model_class``.

    Args:
        filepath: Path to the checkpoint file.
        optimizer: Currently unused; accepted only for call-site
            compatibility. The optimizer state is returned exactly as it was
            pickled and is NOT re-initialised here.
        model_class: Class used to reconstruct the model. Defaults to
            ``sae_jax.SparseAutoencoder`` when None.

    Returns:
        Tuple of (model, optimizer_state, step).
    """
    # Import here to avoid circular imports
    if model_class is None:
        from sae_jax import SparseAutoencoder
        model_class = SparseAutoencoder

    # SECURITY: pickle.load can execute arbitrary code -- only load trusted files.
    with open(filepath, 'rb') as f:
        checkpoint = pickle.load(f)

    saved_model = checkpoint['model']

    def _to_jax(entry):
        # Mirror of save_checkpoint's conversion: one level of nested dicts.
        if isinstance(entry, dict):
            return {leaf_name: jnp.array(leaf) for leaf_name, leaf in entry.items()}
        return jnp.array(entry)

    params = {name: _to_jax(entry) for name, entry in saved_model['params'].items()}

    # Reconstruct the model
    model = model_class(
        params=params,
        k=saved_model['k'],
        embed_dim=saved_model['embed_dim'],
        hidden_dim=saved_model['hidden_dim'],
    )

    # Optimizer state is returned untouched (see `optimizer` note above).
    optimizer_state = checkpoint['optimizer_state']
    step = checkpoint['step']

    print(f"Checkpoint loaded from {filepath} at step {step}")
    return model, optimizer_state, step


class TopKActivation(nn.Module):
    """
    Keep the ``k`` largest entries along the last dimension and zero the rest.

    Selection uses ``torch.topk`` on the raw (signed) values, so large negative
    entries are dropped rather than kept. Input of any rank >= 1 is viewed as
    ``(-1, last_dim)`` for the selection and reshaped back to its original
    shape. ``torch.topk`` raises if ``k`` exceeds the size of the last dimension.
    """

    def __init__(self, k: int):
        super().__init__()
        self.k = k  # number of entries kept per row

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rows = x.reshape(-1, x.shape[-1])
        keep_idx = rows.topk(self.k, dim=-1).indices
        # 0/1 mask with ones at the selected positions of each row.
        keep = torch.zeros_like(rows).scatter_(1, keep_idx, 1.0)
        return (rows * keep).reshape(x.shape)

# Sparse Autoencoder - PyTorch Implementation
class SparseAutoencoder(nn.Module):
    """
    PyTorch sparse autoencoder with a tied input/output bias and a top-k bottleneck.

    forward(x) = TopK((x - tied_bias) @ encoder_weight + encoder_bias)
                 @ decoder_weight + decoder_bias + tied_bias

    Weights are stored as (in_features, out_features) and applied through
    ``F.linear(x, W.t(), b)`` (i.e. ``x @ W + b``):
    encoder_weight (embed_dim, hidden_dim), encoder_bias (hidden_dim,),
    decoder_weight (hidden_dim, embed_dim), decoder_bias (embed_dim,),
    tied_bias (embed_dim,).
    """

    def __init__(self, embed_dim: int, hidden_dim: int, k: int, bias_init: float = 0.0):
        super().__init__()
        self.k = k
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim

        # Encoder components
        self.encoder_weight = nn.Parameter(torch.empty(embed_dim, hidden_dim))
        self.encoder_bias = nn.Parameter(torch.zeros(hidden_dim))

        # Decoder components
        self.decoder_weight = nn.Parameter(torch.empty(hidden_dim, embed_dim))
        self.decoder_bias = nn.Parameter(torch.zeros(embed_dim))

        # Tied bias: subtracted before encoding, added back after decoding.
        # float() is required: torch.full infers its dtype from the fill value,
        # so an int bias_init (e.g. 0) would give an int64 tensor, which
        # nn.Parameter rejects because integer tensors cannot require grad.
        self.tied_bias = nn.Parameter(torch.full((embed_dim,), float(bias_init)))

        # Initialize weights (similar to JAX implementation)
        nn.init.normal_(self.encoder_weight, std=0.02)
        nn.init.normal_(self.decoder_weight, std=0.02)

        # TopK activation
        self.topk_activation = TopKActivation(k=k)

    def _pre_activation(self, x: torch.Tensor) -> torch.Tensor:
        """Encoder output before top-k: (x - tied_bias) @ encoder_weight + encoder_bias."""
        return F.linear(x - self.tied_bias, self.encoder_weight.t(), self.encoder_bias)

    def encode(self, x: torch.Tensor):
        """Encode only, for use in training the MLP model; returns post-top-k activations."""
        return self.topk_activation(self._pre_activation(x))

    def forward(self, x: torch.Tensor, return_intermediates: bool = False):
        """
        Reconstruct ``x`` through the sparse bottleneck.

        Args:
            x: Input with last dimension ``embed_dim``.
            return_intermediates: If True, also return a dict with the
                ``"pre_activation"`` and ``"post_activation"`` tensors.

        Returns:
            The reconstruction, or ``(reconstruction, intermediates)``.
        """
        encoded = self._pre_activation(x)
        activated = self.topk_activation(encoded)

        decoded = F.linear(activated, self.decoder_weight.t(), self.decoder_bias)
        output = decoded + self.tied_bias

        if return_intermediates:
            return output, {"pre_activation": encoded, "post_activation": activated}
        return output

# Build a PyTorch SparseAutoencoder from a model saved by save_model, optionally copying its weights.
def load_jax_sae_to_pytorch(sae_model_path, pt_model=None, load_weights=True):
    """
    Load parameters from a JAX Sparse Autoencoder model into a PyTorch model.

    The JAX model is read with ``load_model`` to obtain ``embed_dim``,
    ``hidden_dim`` and ``k``. When ``load_weights`` is true, its five parameter
    arrays are copied as-is (no transposes) into the PyTorch model as float32.

    Args:
        sae_model_path: Path to the saved JAX model.
        pt_model: Optional PyTorch model. If None, a new ``SparseAutoencoder``
            is created with ``bias_init=0.0``.
        load_weights: If True, loads weights from JAX model into PyTorch model.
            If False, only creates/returns PyTorch model with same architecture.

    Returns:
        The PyTorch model (with or without loaded parameters depending on load_weights).
    """
    sae_jax = load_model(sae_model_path)

    if pt_model is None:
        pt_model = SparseAutoencoder(
            embed_dim=sae_jax.embed_dim,
            hidden_dim=sae_jax.hidden_dim,
            k=sae_jax.k,
            bias_init=0.0,
        )

    if load_weights:
        _copy_jax_params_into(pt_model, sae_jax.params)

    # Drop our reference to the JAX model *before* collecting; otherwise this
    # local keeps it alive and gc.collect() cannot reclaim it.
    del sae_jax
    gc.collect()

    # clear JAX's compilation/staging caches
    jax.clear_caches()

    return pt_model


def _copy_jax_params_into(pt_model, jax_params):
    """Copy the five JAX parameter arrays into ``pt_model``'s parameters as float32."""
    pairs = (
        (jax_params["encoder"]["weights"], pt_model.encoder_weight),  # (embed_dim, hidden_dim)
        (jax_params["encoder"]["bias"], pt_model.encoder_bias),       # (hidden_dim,)
        (jax_params["decoder"]["weights"], pt_model.decoder_weight),  # (hidden_dim, embed_dim)
        (jax_params["decoder"]["bias"], pt_model.decoder_bias),       # (embed_dim,)
        (jax_params["tied_bias"], pt_model.tied_bias),                # (embed_dim,)
    )
    for src, dst in pairs:
        # np.array makes a fresh, writable host copy, which torch.from_numpy can wrap.
        dst.data.copy_(torch.from_numpy(np.array(src)).float())

def encode_sparse_torch(model, x, batch_size=None):
    """
    Efficiently convert inputs to sparse codes using PyTorch.

    For each row, ``encoded = (x - tied_bias) @ encoder_weight + encoder_bias``.
    Every entry whose magnitude is >= the k-th largest magnitude in that row is
    kept, and the rest are zeroed. Because of the ``>=`` threshold, ties at the
    k-th magnitude can leave more than ``k`` non-zeros in a row. All work runs
    under ``torch.no_grad()``.

    NOTE(review): selection here is by |value|, whereas ``TopKActivation``
    (used by ``SparseAutoencoder.encode``/``forward``) keeps the largest signed
    values. The two disagree when pre-activations are negative; confirm which
    one matches the JAX model.

    Args:
        model: PyTorch SparseAutoencoder model (needs ``tied_bias``,
            ``encoder_weight``, ``encoder_bias``, ``k`` and at least one
            parameter, which determines the device).
        x: Input data (numpy array, array-like or torch tensor). Non-tensor
            input is copied into a float32 tensor.
        batch_size: Optional batch size. If given and smaller than
            ``x.shape[0]``, rows are encoded in chunks and concatenated.

    Returns:
        sparse_codes: Sparse activations after top-k, on the model's device,
        with last dimension ``encoder_weight.shape[1]``.
    """
    if not isinstance(x, torch.Tensor):
        # Copy into a fresh, writable array first so read-only inputs convert cleanly.
        x = torch.tensor(np.array(x, copy=True), dtype=torch.float32)

    device = next(model.parameters()).device

    if batch_size is None or x.shape[0] <= batch_size:
        return _sparse_codes_for_batch(model, x.to(device))

    chunks = [
        _sparse_codes_for_batch(model, x[start:start + batch_size].to(device))
        for start in range(0, x.shape[0], batch_size)
    ]
    return torch.cat(chunks, dim=0)


def _sparse_codes_for_batch(model, batch):
    """Encode one on-device batch and keep the entries with the top-k magnitudes per row."""
    with torch.no_grad():
        encoded = F.linear(batch - model.tied_bias, model.encoder_weight.t(), model.encoder_bias)
        magnitudes = torch.abs(encoded)
        # k-th largest magnitude per row (kept as a trailing size-1 dim for
        # broadcasting); topk avoids fully sorting each row.
        kth_largest = torch.topk(magnitudes, model.k, dim=-1).values[..., -1:]
        topk_mask = magnitudes >= kth_largest
        return torch.where(topk_mask, encoded, torch.zeros_like(encoded))